import random
from itertools import zip_longest
from typing import Tuple, List

import numpy as np
import pandas as pd
import scipy

from algos import local_sgd, minibatch_sgd, local_sgd_momentum
from config import Config

import functools
import os
import pickle
from concurrent.futures import ProcessPoolExecutor
import glob
from dateutil.relativedelta import relativedelta


# Directory scanned for the preprocessed station data; every "*.csv" inside it is loaded.
preprocessed_data_path = "../data/preprocessed"
# Directory where run_experiment writes one pickle of results per configuration.
results_dir = "../results_time_series"

# Running counter behind generate_random_seed(); each call increments it by one.
seed_value = 0
# Number of seeded runs launched per algorithm within a single experiment.
repeat = 10


def generate_random_seed():
    """
    Advance the module-level seed counter and hand back its new value
    :return: the value of ``seed_value`` after it has been incremented by one
    """
    global seed_value
    seed_value = seed_value + 1
    return seed_value




def least_square(
    test_data: np.ndarray,
    test_label: np.ndarray,
) -> np.ndarray:
    """
    Find the least square minimizer on the test data
    :param test_data: test data of shape (n_stations, n_test_samples, data_dim)
    :param test_label: test label of shape (n_stations, n_test_samples, 1)
    :return: the least square minimizer of shape (data_dim,)
    """
    # Collapse every leading axis so that each sample becomes one row.
    data_flatten = np.reshape(test_data, (-1, test_data.shape[-1]))
    label_flatten = np.reshape(test_label, (-1, test_label.shape[-1]))
    try:
        # Exact solve; numpy raises LinAlgError when the flattened system is
        # not square or is singular, in which case we fall back below.
        w = np.linalg.solve(data_flatten, label_flatten)
    except np.linalg.LinAlgError:
        # rcond=None selects the machine-precision based cutoff explicitly:
        # it avoids the FutureWarning older numpy emits when rcond is omitted
        # and keeps the result independent of the installed numpy version.
        w = np.linalg.lstsq(data_flatten, label_flatten, rcond=None)[0]

    # Drop the trailing label axis: (data_dim, 1) -> (data_dim,)
    return w.squeeze(axis=-1)


def seed_map(
    t
):
    """
    Call the callable stored in t with the seed stored next to it
    :param t: indexable pair; t[0] is a callable taking a seed (here a
        partially applied run_algo) and t[1] is the seed to pass to it
    :return: whatever the callable returns
    """
    runner = t[0]
    seed = t[1]
    return runner(seed)


def run_algo(
    seed: int,
    algo,
    config: Config,
    train_dfs: List[pd.DataFrame],
    test_data: np.ndarray,
    test_label: np.ndarray,
    w_0: np.ndarray,
    n_clients: int,
    n_months: int
):
    """
    Run one seeded trial of an algorithm
    :param seed: seed of the random generator used for this trial
    :param algo: callable accepting the keyword arguments config, train_data,
        train_label, test_data, test_label and w_0
    :param config: configuration forwarded to algo
    :param train_dfs: list of training dfs the clients' data is sampled from
    :param test_data: test features forwarded to algo
    :param test_label: test labels forwarded to algo
    :param w_0: reference model; a random offset is added before it is used
        as the starting point
    :param n_clients: number of clients
    :param n_months: number of months each client has access to
    :return: tuple (algo.__name__, value returned by algo)
    """
    generator = np.random.default_rng(seed)

    # Drawn before the training data is sampled, so the generator is always
    # consumed in the same order for a given seed.
    perturbation = generator.normal(loc=0.1, scale=0.05, size=w_0.shape)

    features, targets = query_train_data(
        train_dfs=train_dfs,
        n_clients=n_clients,
        n_months=n_months,
        rng=generator
    )

    algo_name = algo.__name__
    outcome = algo(
        config=config,
        train_data=features,
        train_label=targets,
        test_data=test_data,
        test_label=test_label,
        w_0=w_0 + perturbation,
    )
    return algo_name, outcome


def query_monthly_df(
    df: pd.DataFrame,
    date_range: pd.DatetimeIndex,
) -> List[pd.DataFrame]:
    """
    Query monthly data and put in separated dfs
    :param df: the df to query from, indexed by datetime
    :param date_range: the start timestamps of the windows to query, one per
        month
    :return: list of monthly dfs, in the order of date_range; each one holds
        the rows from its start timestamp through the end of the last day of
        that month
    """
    one_day = pd.Timedelta(days=1)
    monthly_dfs = []
    for date in date_range:
        # MonthEnd(0) keeps the time of day of `date`, so using it directly
        # as the upper bound would stop at that time on the last day of the
        # month and silently drop the rest of that day for sub-daily data.
        # Use the half-open window [date, midnight after the last day).
        window_end = (date + pd.offsets.MonthEnd(0)).normalize() + one_day
        in_window = (df.index >= date) & (df.index < window_end)
        monthly_dfs.append(df[in_window])

    return monthly_dfs


def train_test_split_df(
    config: Config,
) -> Tuple[List[pd.DataFrame], List[pd.DataFrame]]:
    """
    Load the preprocessed data and split into train and test dfs
    :param config: configuration; config.test_size is the number of trailing
        calendar months put in the test split
    :return:
        train_dfs: List of train dfs, each element corresponds to a station
        test_dfs: List of test dfs, each element corresponds to a station
    :raises FileNotFoundError: if no csv file is found in the preprocessed
        data directory
    """
    # Load preprocessed dfs.
    # glob returns paths in arbitrary, OS-dependent order; sort them so the
    # station order (and every seeded choice made from it) is reproducible.
    all_preprocessed_files = sorted(
        glob.glob(os.path.join(preprocessed_data_path, "*.csv"))
    )
    if not all_preprocessed_files:
        raise FileNotFoundError(
            f"No preprocessed csv files found in {preprocessed_data_path!r}"
        )
    dfs = [
        pd.read_csv(file, index_col="Datetime", header=0, parse_dates=["Datetime"])
        for file in all_preprocessed_files
    ]

    # Train-test split: the test window starts at midnight on the first day
    # of the month that lies (test_size - 1) months before the last month.
    last_date = dfs[0].index.max()
    test_start_date = (
        last_date.replace(day=1) - relativedelta(months=config.test_size - 1)
    ).normalize()
    # NOTE(review): the mask is built from the first station's index and
    # applied positionally to every station, so all files are assumed to
    # share the same index - confirm against the preprocessing step.
    mask = (dfs[0].index >= test_start_date) & (dfs[0].index <= last_date)
    train_dfs = [df[~mask] for df in dfs]
    test_dfs = [df[mask] for df in dfs]

    return train_dfs, test_dfs


def _stack_client_window(
    station_monthly_dfs: List[pd.DataFrame],
    start: int,
    n_months: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Stack n_months consecutive monthly dfs of one station into one client's data
    :param station_monthly_dfs: monthly dfs of a single station
    :param start: position of the first month of the window
    :param n_months: number of consecutive months in the window
    :return:
        features: every column except "pm2.5" and "station", stacked row-wise
        labels: the "pm2.5" column stacked row-wise, with a trailing axis of size 1
    """
    window = [station_monthly_dfs[start + j] for j in range(n_months)]
    features = np.vstack([
        month_df.drop(["pm2.5", "station"], axis=1).values
        for month_df in window
    ])
    labels = np.vstack([
        month_df["pm2.5"].values[..., np.newaxis]
        for month_df in window
    ])
    return features, labels


def query_train_data(
    train_dfs: List[pd.DataFrame],
    n_clients: int,
    n_months: int,
    rng: np.random.Generator,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Query the monthly training data and build the training features and
    labels of every client. Each client receives n_months consecutive
    months of a single station.
    :param n_months: number of months each client has access to.
    :param rng: random generator
    :param n_clients: number of clients
    :param train_dfs: list of training dfs
    :return:
        training_data: array of training features, first axis indexes the
        clients, second axis the samples (truncated to the shortest client)
        training_labels: array of training labels, same leading axes as
        training_data with a trailing axis of size 1
    """
    # get number of stations
    n_stations = len(train_dfs)

    # Query monthly data
    train_last_date = train_dfs[0].index.max()
    train_start_date = train_dfs[0].index.min()
    date_range = pd.date_range(start=train_start_date, end=train_last_date, freq="MS")

    # train_monthly_dfs[i] is the list of monthly dfs of station i
    train_monthly_dfs = [
        query_monthly_df(train_df, date_range)
        for train_df in train_dfs
    ]

    # Decide which (station, first month) window each client gets. All random
    # draws happen here, in the same order as before the helper was extracted.
    # NOTE(review): start positions are drawn from range(len - n_months), so
    # the window that ends on the final month is never selected - confirm
    # whether that is intended.
    windows = []
    if n_clients >= n_stations:
        # Every station serves the same number of clients; when n_clients is
        # not a multiple of n_stations the remainder is dropped.
        n_to_choose = n_clients // n_stations
        for i in range(n_stations):
            chosen_idx = rng.choice(
                a=len(train_monthly_dfs[i]) - n_months,
                size=n_to_choose,
                replace=False,
            )
            windows.extend((i, indice) for indice in chosen_idx)
    else:
        # Fewer clients than stations: one client on each of n_clients
        # distinct stations.
        chosen_stations = rng.choice(
            a=n_stations,
            size=n_clients,
            replace=False,
        )
        for station in chosen_stations:
            chosen_started_month = rng.choice(
                a=len(train_monthly_dfs[station]) - n_months,
                size=1
            )[0]
            windows.append((station, chosen_started_month))

    data = []
    label = []
    for station, start in windows:
        features, labels = _stack_client_window(
            train_monthly_dfs[station], start, n_months
        )
        data.append(features)
        label.append(labels)

    # Windows can differ in length; truncate every client to the shortest one
    # so the per-client arrays can be stacked.
    min_samples = min(x.shape[0] for x in data)
    training_data = np.stack([x[:min_samples] for x in data], axis=0)
    training_labels = np.stack([x[:min_samples] for x in label], axis=0)

    return training_data, training_labels


def train_test_split(
    config: Config
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Read the preprocessed dfs stored in the csv files and split to
    training and test sets.
    :param config: configuration
    :return:
        train_data: training dfs of shape (n_stations, n_training_samples, data_dim)
        train_label: training label of shape (n_stations, n_training_samples, 1)
        test_data: test dfs of shape (n_stations, n_test_samples, data_dim)
        test_label: test label of shape (n_stations, n_test_samples, 1)
    """
    # Loading and splitting is delegated to train_test_split_df instead of
    # being duplicated here, so both entry points always agree on the split.
    train_dfs, test_dfs = train_test_split_df(config)

    def to_arrays(station_dfs):
        # One slice per station along the first axis: features are every
        # column except "pm2.5" and "station"; labels are "pm2.5" with a
        # trailing axis of size 1.
        features = np.stack([
            station_df.drop(["pm2.5", "station"], axis=1).values
            for station_df in station_dfs
        ], axis=0)
        labels = np.stack([
            station_df["pm2.5"].values[..., np.newaxis]
            for station_df in station_dfs
        ], axis=0)
        return features, labels

    train_data, train_label = to_arrays(train_dfs)
    test_data, test_label = to_arrays(test_dfs)

    return train_data, train_label, test_data, test_label


def run_experiment(
    n_months: int,
    global_lr: float,
    local_lr: float,
    momentum_coef: float,
    local_steps: int,
    regularization: bool,
    n_clients: int,
    test_size: int = 5,
):
    """
    Run local_sgd, minibatch_sgd and local_sgd_momentum `repeat` times each
    for one setting and pickle the collected losses and gradient norms
    :param n_months: number of months each client has access to
    :param global_lr: stored in the configuration as global_lr
    :param local_lr: stored in the configuration as local_lr
    :param momentum_coef: stored in the configuration as momentum_coef
    :param local_steps: stored in the configuration as local_steps
    :param regularization: stored in the configuration as regularization
    :param n_clients: number of clients
    :param test_size: stored in the configuration as test_size; used as the
        number of trailing months of the test split
    :return: None, the results are written to a pickle file in results_dir
    """
    # Make sure the output directory exists
    os.makedirs(results_dir, exist_ok=True)

    # Build the configuration shared by every run
    config = Config(
        global_lr=global_lr,
        local_lr=local_lr,
        momentum_coef=momentum_coef,
        local_steps=local_steps,
        regularization=regularization,
        test_size=test_size,
    )

    # Load preprocessed data as dfs
    train_dfs, test_dfs = train_test_split_df(config)

    # Test features and labels, one slice per station
    test_data = np.stack(
        [station_df.drop(["pm2.5", "station"], axis=1).values for station_df in test_dfs],
        axis=0,
    )
    test_label = np.stack(
        [station_df["pm2.5"].values[..., np.newaxis] for station_df in test_dfs],
        axis=0,
    )

    # Reference model: least square fit of the test data. Every run adds its
    # own random offset to it inside run_algo.
    w_0 = least_square(test_data, test_label)

    print(
        f"---- M = {n_clients}, K = {config.local_steps}, "
        f"local_lr = {config.local_lr}, global_lr = {config.global_lr}, "
        f"momentum = {config.momentum_coef}, "
        f"regularization = {config.regularization}----"
    )

    # One seed-taking callable per algorithm; everything but the algorithm
    # itself is shared.
    shared_kwargs = dict(
        config=config,
        train_dfs=train_dfs,
        test_data=test_data,
        test_label=test_label,
        w_0=w_0,
        n_clients=n_clients,
        n_months=n_months,
    )
    runners = [
        functools.partial(run_algo, algo=algo, **shared_kwargs)
        for algo in (local_sgd, minibatch_sgd, local_sgd_momentum)
    ]

    with ProcessPoolExecutor(max_workers=12) as executor:
        # Seeds are handed out in task order: for each repetition one seed
        # per algorithm, in the order the runners are listed.
        jobs = [
            (runner, generate_random_seed())
            for _ in range(repeat)
            for runner in runners
        ]
        res = list(executor.map(seed_map, jobs))

    # Group the outputs by the algorithm name reported by run_algo
    algo_names = ("local_sgd", "minibatch_sgd", "local_sgd_momentum")
    curves = {}
    for name in algo_names:
        runs = [output for tag, output in res if tag == name]
        curves[name] = {
            "loss": [runs[i][0] for i in range(repeat)],
            "grad_norm": [runs[i][1] for i in range(repeat)],
        }

    def pad_to_array(sequences):
        # Runs may have different lengths; shorter ones are padded with 0 so
        # that each run becomes one row of a rectangular array.
        return np.array(list(zip_longest(*sequences, fillvalue=0))).T

    # Save the results
    file_name = (
        f"local_lr={config.local_lr},global_lr={config.global_lr},"
        f"momentum={config.momentum_coef},local_steps={config.local_steps},"
        f"n_clients={n_clients},n_months={n_months},"
        f"regularization={config.regularization}"
    )
    with open(os.path.join(results_dir, f"{file_name}.pkl"), "wb") as f:
        payload = {"config": config}
        for name in algo_names:
            payload[name] = {
                "loss": pad_to_array(curves[name]["loss"]),
                "grad_norm": pad_to_array(curves[name]["grad_norm"]),
            }
        pickle.dump(payload, f)


def main():
    """
    Run one experiment for every combination of the swept settings
    (months per client, number of clients, local learning rate and
    local steps); the remaining settings are fixed.
    """
    # Same nesting order as four nested loops: local_steps varies fastest,
    # n_months slowest.
    grid = [
        (n_months, n_clients, local_lr, local_steps)
        for n_months in [6, 12]
        for n_clients in [6, 24, 240]
        for local_lr in [0.01, 0.001]
        for local_steps in [10, 50, 100]
    ]
    for n_months, n_clients, local_lr, local_steps in grid:
        run_experiment(
            n_months=n_months,
            global_lr=0.1,
            local_lr=local_lr,
            momentum_coef=0.5,
            local_steps=local_steps,
            regularization=True,
            n_clients=n_clients,
            test_size=12,
        )


# Start the sweep only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
